{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 02 scikit-learn 中的 kNN "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "raw_data_X = [[3.393533211, 2.331273381],\n",
    "              [3.110073483, 1.781539638],\n",
    "              [1.343808831, 3.368360954],\n",
    "              [3.582294042, 4.679179110],\n",
    "              [2.280362439, 2.866990263],\n",
    "              [7.423436942, 4.696522875],\n",
    "              [5.745051997, 3.533989803],\n",
    "              [9.172168622, 2.511101045],\n",
    "              [7.792783481, 3.424088941],\n",
    "              [7.939820817, 0.791637231]\n",
    "             ]\n",
    "raw_data_y = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]\n",
    "\n",
    "X_train = np.array(raw_data_X)\n",
    "y_train = np.array(raw_data_y)\n",
    "\n",
    "x = np.array([8.093607318, 3.365731514])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%run kNN_function/kNN.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "predict_y = kNN_classify(6, X_train, y_train, x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "1"
     },
     "metadata": {},
     "execution_count": 30
    }
   ],
   "source": [
    "predict_y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 使用scikit-learn中的kNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# scikit-learn wraps its algorithms in an object-oriented (estimator class) API\n",
    "from sklearn.neighbors import KNeighborsClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Pass the value of k as n_neighbors\n",
    "kNN_classifier = KNeighborsClassifier(n_neighbors=6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n                     metric_params=None, n_jobs=None, n_neighbors=6, p=2,\n                     weights='uniform')"
     },
     "metadata": {},
     "execution_count": 33
    }
   ],
   "source": [
    "# fit: the training (fitting) step\n",
    "# fit returns the classifier instance itself\n",
    "# After this call the instance holds the fitted model\n",
    "kNN_classifier.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "array([1])"
     },
     "metadata": {},
     "execution_count": 34
    }
   ],
   "source": [
    "# kNN_classifier.predict(x)   # raises an error: predict requires a 2D array (one row per sample)\n",
    "kNN_classifier.predict(x.reshape((1, -1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "X_predict = x.reshape(1, -1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "array([[8.09360732, 3.36573151]])"
     },
     "metadata": {},
     "execution_count": 36
    }
   ],
   "source": [
    "X_predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "array([1])"
     },
     "metadata": {},
     "execution_count": 37
    }
   ],
   "source": [
    "kNN_classifier.predict(X_predict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "y_predict = kNN_classifier.predict(X_predict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "1"
     },
     "metadata": {},
     "execution_count": 39
    }
   ],
   "source": [
    "y_predict[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 重新整理我们的kNN的代码"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "代码参见 [这里](kNN/kNN.py)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%run kNN/kNN.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "knn_clf = KNNClassifier(k=6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "KNN(k=6)"
     },
     "metadata": {},
     "execution_count": 42
    }
   ],
   "source": [
    "knn_clf.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "y_predict = knn_clf.predict(X_predict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "array([1])"
     },
     "metadata": {},
     "execution_count": 44
    }
   ],
   "source": [
    "y_predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "1"
     },
     "metadata": {},
     "execution_count": 45
    }
   ],
   "source": [
    "y_predict[0]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.6.8 64-bit ('anaconda': conda)",
   "language": "python",
   "name": "python36864bitanacondaconda085db2e8b14649f4b32196068f7124e9"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7-final"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}